{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sliceguard – Find critical data segments in your data (fast)\n",
    "## Structured Data Walkthrough"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "🕹️ **[Interactive Demo](https://huggingface.co/spaces/renumics/sliceguard-structured-data)**\n",
    "\n",
    "Sliceguard is a Python library for quickly finding **critical data slices** like outliers, errors, or biases. It works on **structured** and **unstructured** data.\n",
    "\n",
    "This notebook covers the **structured** data case. If you have unstructured data like images or audio, have a look at the **[guide for unstructured data](./quickstart_unstructured_data.ipynb)** instead.\n",
    "\n",
    "This notebook is for you if you want to:\n",
    "1. Find **performance issues** of your machine learning model.\n",
    "2. Find **anomalies and inconsistencies** in your data.\n",
    "3. Quickly **explore** your data using an interactive report to generate **insights**."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run this notebook, install and import sliceguard:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U sliceguard[automl]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sliceguard import SliceGuard\n",
    "from sliceguard.data import from_huggingface\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now download the demo dataset from the Hugging Face Hub:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = from_huggingface(\"mstz/wine\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the feature names\n",
    "feature_names = [\"fixed_acidity\", \"volatile_acidity\", \"citric_acid\", \"residual_sugar\",\n",
    "                \"chlorides\", \"free_sulfur_dioxide\", \"total_sulfur_dioxide\", \"density\", \"pH\",\n",
    "                \"sulphates\", \"alcohol\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>fixed_acidity</th>\n",
       "      <th>volatile_acidity</th>\n",
       "      <th>citric_acid</th>\n",
       "      <th>residual_sugar</th>\n",
       "      <th>chlorides</th>\n",
       "      <th>free_sulfur_dioxide</th>\n",
       "      <th>total_sulfur_dioxide</th>\n",
       "      <th>density</th>\n",
       "      <th>pH</th>\n",
       "      <th>sulphates</th>\n",
       "      <th>alcohol</th>\n",
       "      <th>quality</th>\n",
       "      <th>is_red</th>\n",
       "      <th>split</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.99780</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "      <td>red</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.88</td>\n",
       "      <td>0.00</td>\n",
       "      <td>2.6</td>\n",
       "      <td>0.098</td>\n",
       "      <td>25.0</td>\n",
       "      <td>67.0</td>\n",
       "      <td>0.99680</td>\n",
       "      <td>3.20</td>\n",
       "      <td>0.68</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "      <td>red</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>7.8</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.04</td>\n",
       "      <td>2.3</td>\n",
       "      <td>0.092</td>\n",
       "      <td>15.0</td>\n",
       "      <td>54.0</td>\n",
       "      <td>0.99700</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.65</td>\n",
       "      <td>9.8</td>\n",
       "      <td>5</td>\n",
       "      <td>red</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>11.2</td>\n",
       "      <td>0.28</td>\n",
       "      <td>0.56</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.075</td>\n",
       "      <td>17.0</td>\n",
       "      <td>60.0</td>\n",
       "      <td>0.99800</td>\n",
       "      <td>3.16</td>\n",
       "      <td>0.58</td>\n",
       "      <td>9.8</td>\n",
       "      <td>6</td>\n",
       "      <td>red</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>7.4</td>\n",
       "      <td>0.70</td>\n",
       "      <td>0.00</td>\n",
       "      <td>1.9</td>\n",
       "      <td>0.076</td>\n",
       "      <td>11.0</td>\n",
       "      <td>34.0</td>\n",
       "      <td>0.99780</td>\n",
       "      <td>3.51</td>\n",
       "      <td>0.56</td>\n",
       "      <td>9.4</td>\n",
       "      <td>5</td>\n",
       "      <td>red</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6492</th>\n",
       "      <td>6.2</td>\n",
       "      <td>0.21</td>\n",
       "      <td>0.29</td>\n",
       "      <td>1.6</td>\n",
       "      <td>0.039</td>\n",
       "      <td>24.0</td>\n",
       "      <td>92.0</td>\n",
       "      <td>0.99114</td>\n",
       "      <td>3.27</td>\n",
       "      <td>0.50</td>\n",
       "      <td>11.2</td>\n",
       "      <td>6</td>\n",
       "      <td>white</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6493</th>\n",
       "      <td>6.6</td>\n",
       "      <td>0.32</td>\n",
       "      <td>0.36</td>\n",
       "      <td>8.0</td>\n",
       "      <td>0.047</td>\n",
       "      <td>57.0</td>\n",
       "      <td>168.0</td>\n",
       "      <td>0.99490</td>\n",
       "      <td>3.15</td>\n",
       "      <td>0.46</td>\n",
       "      <td>9.6</td>\n",
       "      <td>5</td>\n",
       "      <td>white</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6494</th>\n",
       "      <td>6.5</td>\n",
       "      <td>0.24</td>\n",
       "      <td>0.19</td>\n",
       "      <td>1.2</td>\n",
       "      <td>0.041</td>\n",
       "      <td>30.0</td>\n",
       "      <td>111.0</td>\n",
       "      <td>0.99254</td>\n",
       "      <td>2.99</td>\n",
       "      <td>0.46</td>\n",
       "      <td>9.4</td>\n",
       "      <td>6</td>\n",
       "      <td>white</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6495</th>\n",
       "      <td>5.5</td>\n",
       "      <td>0.29</td>\n",
       "      <td>0.30</td>\n",
       "      <td>1.1</td>\n",
       "      <td>0.022</td>\n",
       "      <td>20.0</td>\n",
       "      <td>110.0</td>\n",
       "      <td>0.98869</td>\n",
       "      <td>3.34</td>\n",
       "      <td>0.38</td>\n",
       "      <td>12.8</td>\n",
       "      <td>7</td>\n",
       "      <td>white</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6496</th>\n",
       "      <td>6.0</td>\n",
       "      <td>0.21</td>\n",
       "      <td>0.38</td>\n",
       "      <td>0.8</td>\n",
       "      <td>0.020</td>\n",
       "      <td>22.0</td>\n",
       "      <td>98.0</td>\n",
       "      <td>0.98941</td>\n",
       "      <td>3.26</td>\n",
       "      <td>0.32</td>\n",
       "      <td>11.8</td>\n",
       "      <td>6</td>\n",
       "      <td>white</td>\n",
       "      <td>train</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>6497 rows × 14 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      fixed_acidity  volatile_acidity  citric_acid  residual_sugar  chlorides  \\\n",
       "0               7.4              0.70         0.00             1.9      0.076   \n",
       "1               7.8              0.88         0.00             2.6      0.098   \n",
       "2               7.8              0.76         0.04             2.3      0.092   \n",
       "3              11.2              0.28         0.56             1.9      0.075   \n",
       "4               7.4              0.70         0.00             1.9      0.076   \n",
       "...             ...               ...          ...             ...        ...   \n",
       "6492            6.2              0.21         0.29             1.6      0.039   \n",
       "6493            6.6              0.32         0.36             8.0      0.047   \n",
       "6494            6.5              0.24         0.19             1.2      0.041   \n",
       "6495            5.5              0.29         0.30             1.1      0.022   \n",
       "6496            6.0              0.21         0.38             0.8      0.020   \n",
       "\n",
       "      free_sulfur_dioxide  total_sulfur_dioxide  density    pH  sulphates  \\\n",
       "0                    11.0                  34.0  0.99780  3.51       0.56   \n",
       "1                    25.0                  67.0  0.99680  3.20       0.68   \n",
       "2                    15.0                  54.0  0.99700  3.26       0.65   \n",
       "3                    17.0                  60.0  0.99800  3.16       0.58   \n",
       "4                    11.0                  34.0  0.99780  3.51       0.56   \n",
       "...                   ...                   ...      ...   ...        ...   \n",
       "6492                 24.0                  92.0  0.99114  3.27       0.50   \n",
       "6493                 57.0                 168.0  0.99490  3.15       0.46   \n",
       "6494                 30.0                 111.0  0.99254  2.99       0.46   \n",
       "6495                 20.0                 110.0  0.98869  3.34       0.38   \n",
       "6496                 22.0                  98.0  0.98941  3.26       0.32   \n",
       "\n",
       "      alcohol  quality is_red  split  \n",
       "0         9.4        5    red  train  \n",
       "1         9.8        5    red  train  \n",
       "2         9.8        5    red  train  \n",
       "3         9.8        6    red  train  \n",
       "4         9.4        5    red  train  \n",
       "...       ...      ...    ...    ...  \n",
       "6492     11.2        6  white  train  \n",
       "6493      9.6        5  white  train  \n",
       "6494      9.4        6  white  train  \n",
       "6495     12.8        7  white  train  \n",
       "6496     11.8        6  white  train  \n",
       "\n",
       "[6497 rows x 14 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show dataframe\n",
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check for data slices that are particularly different (outliers/errors in the data)\n",
    "Here sliceguard trains an **outlier detection** model to highlight data points that differ strongly from the rest. Note that the outlier detection runs only on the class \"red\" here. You will quickly see several large outliers, e.g. with extremely high residual sugar values.\n",
    "\n",
    "You can use the **report feature**, which relies on [Renumics Spotlight](https://github.com/Renumics/spotlight) for visualization, to dig into why a cluster is considered an outlier. For example, use multiple **histograms** to see what is characteristic of a data point."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(df[df[\"is_red\"] == \"red\"], features=feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = sg.report()"
   ]
  },
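  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To cross-check the residual sugar outliers mentioned above without opening the report, a plain pandas summary can help. The cell below is a minimal sketch that only assumes the `df` loaded earlier. The 1.5 × IQR cutoff is an illustrative rule of thumb chosen for this example, not the method sliceguard uses internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarize residual sugar for red wines only (the same subset used in the outlier run above)\n",
    "red = df[df[\"is_red\"] == \"red\"]\n",
    "print(red[\"residual_sugar\"].describe())\n",
    "\n",
    "# Illustrative rule of thumb (an assumption, not sliceguard's method): flag values above Q3 + 1.5 * IQR\n",
    "q1, q3 = red[\"residual_sugar\"].quantile([0.25, 0.75])\n",
    "upper = q3 + 1.5 * (q3 - q1)\n",
    "red[red[\"residual_sugar\"] > upper].sort_values(\"residual_sugar\", ascending=False).head()"
   ]
  },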
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check for data slices where models are prone to fail (hard samples, inconsistencies)\n",
    "Here sliceguard will **train a classification model** and check for data slices where the accuracy score is particularly bad. This surfaces wines that are easily confused based on their chemical values, e.g. red wines with particularly high residual sugar levels or especially low chloride values.\n",
    "\n",
    "Note that the dataset is relatively easy, so sliceguard won't find as many issues as it would in more complex datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model and predict on the same data. In practice, you will want to split your data.\n",
    "# This cell only demonstrates the principle.\n",
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(df,\n",
    "                        features=feature_names,\n",
    "                        y=\"is_red\",\n",
    "                        metric=accuracy_score,\n",
    "                        automl_task=\"classification\"\n",
    "                       ) # also try out drop_reference=\"parent\" for more class-specific results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = sg.report()"
   ]
  },
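  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The inline comment in the call above suggests trying `drop_reference=\"parent\"` for more class-specific results. The cell below is a sketch of that variant: it reuses the same arguments and only adds the parameter, whose name and value are taken from that comment and are not verified further here. The variable names `sg_parent` and `issues_parent` are placeholders chosen so that the earlier results stay available for comparison."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same call as above, with the drop_reference option mentioned in the earlier comment\n",
    "sg_parent = SliceGuard()\n",
    "issues_parent = sg_parent.find_issues(df,\n",
    "                        features=feature_names,\n",
    "                        y=\"is_red\",\n",
    "                        metric=accuracy_score,\n",
    "                        automl_task=\"classification\",\n",
    "                        drop_reference=\"parent\")"
   ]
  },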
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Check for weaknesses of your own model (...and hard samples + inconsistencies)\n",
    "This shows how to pass your **own model predictions** to sliceguard to find slices that perform badly according to a supplied metric function. This allows you to uncover **inconsistencies** and samples that are **hard to learn** in no time!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the model and predict on the same data. In practice, you will want to split your data.\n",
    "# This cell only demonstrates the principle.\n",
    "X = StandardScaler().fit_transform(df[feature_names].values)\n",
    "y = df[\"is_red\"]\n",
    "\n",
    "clf = SVC()\n",
    "clf.fit(X, y)\n",
    "df[\"predictions\"] = clf.predict(X)"
   ]
  },
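  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A global score gives context for the per-slice results. The cell below is a small sketch that uses only pandas and scikit-learn on the `predictions` column created above. Because the model is evaluated on its own training data, the numbers are optimistic."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Overall accuracy on the rows the model was trained on (optimistic by construction)\n",
    "print(accuracy_score(df[\"is_red\"], df[\"predictions\"]))\n",
    "\n",
    "# Confusion table of actual vs. predicted wine color\n",
    "pd.crosstab(df[\"is_red\"], df[\"predictions\"], rownames=[\"actual\"], colnames=[\"predicted\"])"
   ]
  },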
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the predictions to sliceguard and uncover hard samples and inconsistencies.\n",
    "sg = SliceGuard()\n",
    "issues = sg.find_issues(df,\n",
    "                        features=feature_names,\n",
    "                        y=\"is_red\",\n",
    "                        y_pred=\"predictions\",\n",
    "                        metric=accuracy_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = sg.report()"
   ]
  },
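  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The SVC example above fits and predicts on the same rows, which its own comments flag as a simplification. A hold-out variant could look like the sketch below. It assumes scikit-learn's `train_test_split` with an arbitrary 70/30 split and `random_state=0`; the names `df_train`, `df_test`, `scaler`, `clf_holdout`, `sg_holdout`, and `issues_holdout` are placeholders introduced here. The `find_issues` arguments are the same as in the call above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 30% of the rows; stratify so both wine colors appear in each split\n",
    "df_train, df_test = train_test_split(df, test_size=0.3, stratify=df[\"is_red\"], random_state=0)\n",
    "df_test = df_test.reset_index(drop=True)  # fresh, contiguous index for the held-out rows\n",
    "\n",
    "# Fit the scaler on the training split only, so no test statistics leak into training\n",
    "scaler = StandardScaler().fit(df_train[feature_names].values)\n",
    "clf_holdout = SVC()\n",
    "clf_holdout.fit(scaler.transform(df_train[feature_names].values), df_train[\"is_red\"])\n",
    "df_test[\"predictions\"] = clf_holdout.predict(scaler.transform(df_test[feature_names].values))\n",
    "\n",
    "# Same find_issues call as above, but on rows the model has not seen during training\n",
    "sg_holdout = SliceGuard()\n",
    "issues_holdout = sg_holdout.find_issues(df_test,\n",
    "                        features=feature_names,\n",
    "                        y=\"is_red\",\n",
    "                        y_pred=\"predictions\",\n",
    "                        metric=accuracy_score)"
   ]
  },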
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
